{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Batch Inference with Structural Outputs (Guided Decoding)\n",
    "\n",
    "Structural output (also known as guided decoding or JSON mode) is a useful feature that ensures LLM responses follow a given output schema, expressed either as JSON or as a context-free grammar.\n",
    "\n",
    "In this example, we show how to perform batch inference using Ray Data LLM with structural outputs in JSON format. To run this example, we need to install the following dependencies:\n",
    "\n",
    "```bash\n",
    "pip install -qU \"ray[data]\" \"vllm==0.7.2\" \"xgrammar==0.1.11\"\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pydantic import BaseModel\n",
    "\n",
    "import ray\n",
    "from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig\n",
    "\n",
    "# 1. Construct a guided decoding schema. It can be:\n",
    "# choice: List[str]\n",
    "# json: str\n",
    "# grammar: str\n",
    "# See https://docs.vllm.ai/en/latest/getting_started/examples/structured_outputs.html\n",
    "# for more details about how to construct the schema. Here we use JSON as an example.\n",
    "class AnswerWithExplain(BaseModel):\n",
    "    problem: str\n",
    "    answer: int\n",
    "    explain: str\n",
    "\n",
    "json_schema = AnswerWithExplain.model_json_schema()\n",
    "\n",
    "# 2. Construct a vLLM processor config.\n",
    "processor_config = vLLMEngineProcessorConfig(\n",
    "    # The base model.\n",
    "    model_source=\"unsloth/Llama-3.2-1B-Instruct\",\n",
    "    # vLLM engine config.\n",
    "    engine_kwargs=dict(\n",
    "        # Specify the guided decoding library to use. The default is \"xgrammar\".\n",
    "        # See https://docs.vllm.ai/en/latest/serving/engine_args.html\n",
    "        # for other available libraries.\n",
    "        guided_decoding_backend=\"xgrammar\",\n",
    "        # Older GPUs (e.g. T4) don't support bfloat16. You should remove\n",
    "        # this line if you're using later GPUs.\n",
    "        dtype=\"half\",\n",
    "        # Reduce the model length to fit small GPUs. You should remove\n",
    "        # this line if you're using large GPUs.\n",
    "        max_model_len=1024,\n",
    "    ),\n",
    "    # The batch size used in Ray Data.\n",
    "    batch_size=16,\n",
    "    # Use one GPU in this example.\n",
    "    concurrency=1,\n",
    ")\n",
    "\n",
    "# 3. Construct a processor using the processor config.\n",
    "processor = build_llm_processor(\n",
    "    processor_config,\n",
    "    # Convert the input data to the OpenAI chat form.\n",
    "    preprocess=lambda row: dict(\n",
    "        messages=[\n",
    "            {\n",
    "                \"role\": \"system\",\n",
    "                \"content\": \"You are a math teacher. Give the answer to \"\n",
    "                \"the equation and explain it. Output the problem, answer and \"\n",
    "                \"explanation in JSON\",\n",
    "            },\n",
    "            {\n",
    "                \"role\": \"user\",\n",
    "                \"content\": f\"3 * {row['id']} + 5 = ?\",\n",
    "            },\n",
    "        ],\n",
    "        sampling_params=dict(\n",
    "            temperature=0.3,\n",
    "            max_tokens=150,\n",
    "            detokenize=False,\n",
    "            # Specify the guided decoding schema.\n",
    "            guided_decoding=dict(json=json_schema),\n",
    "        ),\n",
    "    ),\n",
    "    # Only keep the generated text in the output dataset.\n",
    "    postprocess=lambda row: {\n",
    "        \"resp\": row[\"generated_text\"],\n",
    "    },\n",
    ")\n",
    "\n",
    "# 4. Synthesize a dataset with 30 rows.\n",
    "# Each row has a single column \"id\" ranging from 0 to 29.\n",
    "ds = ray.data.range(30)\n",
    "# 5. Apply the processor to the dataset. Note that this line won't kick off\n",
    "# anything because the processor executes lazily.\n",
    "ds = processor(ds)\n",
    "# Materialization kicks off the pipeline execution.\n",
    "ds = ds.materialize()\n",
    "\n",
    "# 6. Print all outputs.\n",
    "# Example output:\n",
    "# {\n",
    "#     \"problem\": \"3 * 6 + 5 = ?\",\n",
    "#     \"answer\": 23,\n",
    "#     \"explain\": \"To solve this equation, we need to follow the order of\n",
    "#       operations (PEMDAS): Parentheses, Exponents, Multiplication and Division,\n",
    "#       and Addition and Subtraction. In this case, we first multiply 3 and 6,\n",
    "#       which equals 18. Then we add 5 to 18, which equals 23.\"\n",
    "# }\n",
    "for out in ds.take_all():\n",
    "    print(out[\"resp\"])\n",
    "    print(\"==========\")\n",
    "\n",
    "# 7. Shut down Ray to release resources.\n",
    "ray.shutdown()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
